import sys
import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from models.sde import init_sde
from models.samplers import cond_ode_likelihood, pose_cond_ode_sampler, joint_cond_ode_sampler, cond_pc_sampler
from ipdb import set_trace
from models.metrics import get_metrics, get_rot_matrix
from models.misc import average_quaternion_batch, exists_or_mkdir
from models.archs.encoders.pointnet2 import Pointnet2ClsMSG, SegNet
import pytorch3d


# Default compute device for this module: the GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def get_pose_dim(rot_mode):
    """Return the length of a pose vector for the given rotation encoding.

    Args:
        rot_mode (str): one of 'quat_wxyz', 'quat_xyzw', 'euler_xyz',
            'euler_xyz_sx_cx' or 'rot_matrix'.

    Returns:
        int: 7 for the quaternion modes, 6 for 'euler_xyz', and 9 for
        'euler_xyz_sx_cx' and 'rot_matrix'.

    Raises:
        AssertionError: if ``rot_mode`` is not supported.
        NotImplementedError: same condition, when assertions are disabled (``python -O``).
    """
    dim_table = (
        (('quat_wxyz', 'quat_xyzw'), 7),
        (('euler_xyz',), 6),
        (('euler_xyz_sx_cx', 'rot_matrix'), 9),
    )
    supported = tuple(mode for modes, _ in dim_table for mode in modes)
    assert rot_mode in supported, \
        f"the rotation mode {rot_mode} is not supported!"

    for modes, dim in dim_table:
        if rot_mode in modes:
            return dim
    # Only reachable when the assert above has been stripped.
    raise NotImplementedError

class RotHead(nn.Module):
    """Point-wise conv head that pools over points and regresses ``out_dim`` values.

    Expects ``x`` as ``[bs, in_feat_dim, num_pts]`` (Conv1d layout) and returns
    ``[bs, out_dim]``: two 1x1 conv + BN + ReLU layers per point, a max over the
    point axis, then conv + BN + ReLU, dropout and a final 1x1 conv.
    """

    def __init__(self, in_feat_dim, out_dim=3):
        super().__init__()
        self.f = in_feat_dim  # input channels
        self.k = out_dim      # output size

        # Layers are created in this exact order so random initialisation and
        # state_dict keys stay the same.
        self.conv1 = torch.nn.Conv1d(self.f, 1024, 1)
        self.conv2 = torch.nn.Conv1d(1024, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 256, 1)
        self.conv4 = torch.nn.Conv1d(256, self.k, 1)
        self.drop1 = nn.Dropout(0.2)
        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(256)

    def forward(self, x):
        feat = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            feat = F.relu(bn(conv(feat)))

        # Max-pool over points, keeping a length-1 axis for the following Conv1d.
        pooled = torch.max(feat, 2, keepdim=True).values

        hidden = self.drop1(F.relu(self.bn3(self.conv3(pooled))))
        out = self.conv4(hidden)
        return out.squeeze(2).contiguous()

class TransHead(nn.Module):
    """Point-wise conv head that pools over points and regresses ``out_dim`` values.

    Same architecture as ``RotHead`` but with ReLU registered as submodules.
    Input ``[bs, in_feat_dim, num_pts]``, output ``[bs, out_dim]``.
    """

    def __init__(self, in_feat_dim, out_dim=3):
        super().__init__()
        self.f = in_feat_dim  # input channels
        self.k = out_dim      # output size

        # Creation order is kept fixed so random initialisation and state_dict
        # keys are unchanged.
        self.conv1 = torch.nn.Conv1d(self.f, 1024, 1)
        self.conv2 = torch.nn.Conv1d(1024, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 256, 1)
        self.conv4 = torch.nn.Conv1d(256, self.k, 1)
        self.drop1 = nn.Dropout(0.2)
        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(256)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()

    def forward(self, x):
        per_point = self.relu1(self.bn1(self.conv1(x)))
        per_point = self.relu2(self.bn2(self.conv2(per_point)))

        # Global max over the point axis; keepdim for the next Conv1d.
        pooled, _ = torch.max(per_point, 2, keepdim=True)

        hidden = self.relu3(self.bn3(self.conv3(pooled)))
        out = self.conv4(self.drop1(hidden))
        return out.squeeze(2).contiguous()

def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``.

    The parameters keep their ``requires_grad`` flags; only their values change.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def weight_init(shape, mode, fan_in, fan_out):
    """Draw a random tensor of ``shape`` scaled by the named initialisation scheme.

    Args:
        shape (sequence of int): output tensor shape.
        mode (str): 'xavier_uniform', 'xavier_normal', 'kaiming_uniform' or
            'kaiming_normal'.
        fan_in (int): input fan used by all modes.
        fan_out (int): output fan used by the xavier modes.

    Returns:
        torch.Tensor: uniform samples in [-1, 1) or standard-normal samples,
        multiplied by the scheme's scale.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    # Each scheme is evaluated lazily so only the selected one touches the RNG.
    initializers = (
        ('xavier_uniform', lambda: np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)),
        ('xavier_normal', lambda: np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)),
        ('kaiming_uniform', lambda: np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)),
        ('kaiming_normal', lambda: np.sqrt(1 / fan_in) * torch.randn(*shape)),
    )
    for name, draw in initializers:
        if mode == name:
            return draw()
    raise ValueError(f'Invalid init mode "{mode}"')


class Dense(nn.Module):
    """Fully connected layer whose output gets two trailing singleton axes.

    An input of shape ``[..., input_dim]`` becomes ``[..., output_dim, 1, 1]``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        return projected.unsqueeze(-1).unsqueeze(-1)


class Linear(torch.nn.Module):
    """Affine layer ``y = x @ W.T + b`` with a selectable initialisation scheme.

    Args:
        in_features (int): input size.
        out_features (int): output size; ``weight`` is ``[out_features, in_features]``.
        bias (bool): create a bias vector when True.
        init_mode (str): scheme passed to ``weight_init``.
        init_weight (float): multiplier applied to the initial weight.
        init_bias (float): multiplier applied to the initial bias (0 gives a zero bias).
    """

    def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        fan = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)

        # Weight is drawn before the bias so the RNG stream is consumed in a fixed order.
        initial_weight = weight_init([out_features, in_features], **fan) * init_weight
        self.weight = torch.nn.Parameter(initial_weight)
        if bias:
            self.bias = torch.nn.Parameter(weight_init([out_features], **fan) * init_bias)
        else:
            self.bias = None

    def forward(self, x):
        # Parameters are cast to the input's dtype before use.
        w = self.weight.to(x.dtype)
        y = x @ w.t()
        if self.bias is None:
            return y
        # y is a fresh tensor produced by the matmul, so it is updated in place.
        return y.add_(self.bias.to(x.dtype))


class GaussianFourierProjection(nn.Module):
    """Random Fourier features for scalar time steps.

    A frozen frequency vector ``W`` of length ``embed_dim // 2`` is drawn as
    ``randn * scale``. An input ``x`` of shape ``[bs]`` maps to
    ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` of shape ``[bs, 2 * (embed_dim // 2)]``.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Sampled once and never trained (requires_grad=False).
        freqs = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(freqs, requires_grad=False)

    def forward(self, x):
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalars.

    Each scalar ``x`` is mapped to ``[cos(x * f), sin(x * f)]`` over
    ``num_channels // 2`` frequencies ``f`` spaced geometrically from 1 towards
    ``1 / max_positions``; the last frequency equals ``1 / max_positions``
    exactly only when ``endpoint=True``.

    Args:
        num_channels (int): output width (an odd value is rounded down).
        max_positions (float): sets the lowest frequency.
        endpoint (bool): include ``1 / max_positions`` as the last frequency.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed ``x`` of shape ``[bs]`` into ``[bs, 2 * (num_channels // 2)]``."""
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        # With endpoint=True and a single frequency the denominator would be 0
        # (0/0 -> NaN); clamp it to 1 so that frequency becomes 1.
        denom = max(half - (1 if self.endpoint else 0), 1)
        freqs = freqs / denom
        freqs = (1 / self.max_positions) ** freqs
        # torch.outer is the supported spelling of the deprecated Tensor.ger.
        x = torch.outer(x, freqs.to(x.dtype))
        return torch.cat([x.cos(), x.sin()], dim=1)


class PoseScoreNet(nn.Module):
    """Conditional score network for pose and joint-parameter diffusion.

    One network handles two sample types, chosen per call by ``data['type']``:

    * ``'pose'``: scores the noisy pose vector ``data['sampled_pose']``.
    * ``'joint'``: scores the noisy joint vector ``data['sampled_joint']``. This is
      ``3 * (num_parts - 1)`` xyz values followed by ``3 * (num_parts - 1)`` rpy values.

    Both branches condition on point-cloud features ``data['pts_feat']`` and the
    diffusion time ``data['t']``. The raw head output is divided by the marginal std at ``t``.
    """

    def __init__(self, specs, marginal_prob_func, pose_mode='quat_wxyz', regression_head='RT', per_point_feature=False):
        """Build the encoders and regression heads.

        Args:
            specs (dict): configuration. This constructor reads 'in_channels',
                'sde_mode' and 'num_parts'. Other methods also read 'repeat_num',
                'likelihood_weighting', 'sampling_steps' and 'pose_mode'.
            marginal_prob_func (func): ``(x, t) -> (mean, std)``. Its std scales the
                head outputs in ``forward``. The losses instead use the
                ``marginal_prob_fn`` returned by ``init_sde``.
            pose_mode (str, optional): pose representation, one of 'quat_wxyz',
                'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'.
                Defaults to 'quat_wxyz'.
            regression_head (str, optional): one of:
                'RT': one head for the whole pose.
                'R_and_T': separate rotation and translation heads.
                'Rx_Ry_and_T': two 3-D rotation heads plus translation; requires
                pose_mode 'rot_matrix'.
                Defaults to 'RT'.
            per_point_feature (bool, optional): use per-point conv heads on
                ``[bs, c, num_pts]`` features. Only 'Rx_Ry_and_T' builds such heads.
                Defaults to False.

        Raises:
            NotImplementedError: unsupported ``regression_head``, or
                'Rx_Ry_and_T' with a pose_mode other than 'rot_matrix'.
        """
        super(PoseScoreNet, self).__init__()
        self.specs = specs
        self.regression_head = regression_head
        self.per_point_feature = per_point_feature
        self.act = nn.ReLU(True)
        self.in_channels = self.specs['in_channels']
        # SDE helpers used by the losses and the samplers.
        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = init_sde(self.specs['sde_mode'])
        pose_dim = get_pose_dim(pose_mode)

        # Conv2d encoder stack. Not referenced by the methods of this class.
        modules = []
        hidden_dims = [256, 256, 256, 256, 256]
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(self.in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            self.in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # Point-cloud / segmentation encoders. Not called in this class; presumably
        # used by training code to produce data['pts_feat'] (TODO confirm).
        self.pts_encoder = Pointnet2ClsMSG(0)
        self.seg_encoder = SegNet(self.specs['num_parts'])

        # Pose vector -> 256-d feature.
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, 256),
            self.act,
            nn.Linear(256, 256),
            self.act,
        )

        # Joint vector (xyz + rpy for each of num_parts - 1 joints) -> 256-d feature.
        self.joint_encoder = nn.Sequential(
            nn.Linear(6 * (specs['num_parts']-1), 256),
            self.act,
            nn.Linear(256, 256),
            self.act,
        )

        # Scalar t -> 128-d feature.
        self.t_encoder = nn.Sequential(
            GaussianFourierProjection(embed_dim=128),
            nn.Linear(128, 128),
            self.act,
        )

        # Joint heads. Input is [pts_feat (1024), t_feat (128), joint_feat (256)];
        # the zero-initialised last layer makes the initial output zero.
        self.joint_xyz_head = nn.Sequential(
            nn.Linear(128 + 256 + 1024, 256),
            self.act,
            zero_module(nn.Linear(256, 3 * (self.specs['num_parts']-1))),
        )
        self.joint_rpy_head = nn.Sequential(
            nn.Linear(128 + 256 + 1024, 256),
            self.act,
            zero_module(nn.Linear(256, 3 * (self.specs['num_parts']-1))),
        )

        # Per-point segmentation logits. Not referenced by the methods of this class.
        self.seg_layer = nn.Sequential(
            nn.Conv1d(1024, 512, kernel_size=1, padding=0),
            nn.Conv1d(512, 256, kernel_size=1, padding=0),
            nn.Conv1d(256, 128, kernel_size=1, padding=0),
            nn.Conv1d(128, self.specs['num_parts'], kernel_size=1, padding=0),
        )

        # Pose heads (fusion tail).
        if self.regression_head == 'RT':
            self.fusion_tail = nn.Sequential(
                nn.Linear(128 + 256 + 1024, 512),
                self.act,
                zero_module(nn.Linear(512, pose_dim)),
            )

        elif self.regression_head == 'R_and_T':
            # Rotation head: every pose entry except the last 3.
            self.fusion_tail_rot = nn.Sequential(
                nn.Linear(128 + 256 + 1024, 256),
                self.act,
                zero_module(nn.Linear(256, pose_dim - 3)),
            )
            # Translation head: the last 3 pose entries.
            self.fusion_tail_trans = nn.Sequential(
                nn.Linear(128 + 256 + 1024, 256),
                self.act,
                zero_module(nn.Linear(256, 3)),
            )

        elif self.regression_head == 'Rx_Ry_and_T':
            if pose_mode != 'rot_matrix':
                raise NotImplementedError
            if per_point_feature:
                # Per-point features carry 1280 point channels instead of 1024.
                self.fusion_tail_rot_x = RotHead(in_feat_dim=128 + 256 + 1280,
                                                 out_dim=3)
                self.fusion_tail_rot_y = RotHead(in_feat_dim=128 + 256 + 1280,
                                                 out_dim=3)
                self.fusion_tail_trans = TransHead(in_feat_dim=128 + 256 + 1280,
                                                   out_dim=3)
            else:
                self.fusion_tail_rot_x = nn.Sequential(
                    nn.Linear(128 + 256 + 1024, 256),
                    self.act,
                    zero_module(nn.Linear(256, 3)),
                )
                self.fusion_tail_rot_y = nn.Sequential(
                    nn.Linear(128 + 256 + 1024, 256),
                    self.act,
                    zero_module(nn.Linear(256, 3)),
                )
                self.fusion_tail_trans = nn.Sequential(
                    nn.Linear(128 + 256 + 1024, 256),
                    self.act,
                    zero_module(nn.Linear(256, 3)),
                )
        else:
            raise NotImplementedError

        self.marginal_prob_func = marginal_prob_func

    def collect_score_loss(self, specs, data, teacher_model=None, pts_feat_teacher=None):
        """Average ``loss_fn`` (pose task) over ``specs['repeat_num']`` random draws.

        Args:
            specs (dict): replaces ``self.specs``. Reads 'repeat_num' and
                'likelihood_weighting'.
            data (dict): batch with 'zero_mean_pts', 'zero_mean_gt_pose'
                [bs, pose_dim] and 'pts_feat'. Modified in place (see ``loss_fn``).
            teacher_model (nn.Module, optional): distillation teacher, see ``loss_fn``.
            pts_feat_teacher (torch.Tensor, optional): point features fed to the teacher.

        Returns:
            torch.Tensor: scalar loss.
        """
        gf_loss = 0
        self.specs = specs

        for _ in range(self.specs['repeat_num']):
            gf_loss += self.loss_fn(
                model=self,
                data=data,
                marginal_prob_func=self.marginal_prob_fn,
                sde_fn=self.sde_fn,
                likelihood_weighting=self.specs['likelihood_weighting'],
                teacher_model=teacher_model,
                pts_feat_teacher=pts_feat_teacher
            )
        gf_loss /= self.specs['repeat_num']
        return gf_loss

    def collect_joint_loss(self, specs, data, teacher_model=None, pts_feat_teacher=None):
        """Average ``loss_joint_fn`` over ``specs['repeat_num']`` random draws.

        Unlike ``collect_score_loss``, this rebuilds the SDE helpers from
        ``specs['sde_mode']`` first.

        Args:
            specs (dict): replaces ``self.specs``. Reads 'sde_mode', 'repeat_num'
                and 'likelihood_weighting'.
            data (dict): batch with 'zero_mean_pts', 'zero_mean_gt_xyz',
                'zero_mean_gt_rpy' and 'pts_feat'. Modified in place.
            teacher_model (nn.Module, optional): distillation teacher.
            pts_feat_teacher (torch.Tensor, optional): point features fed to the teacher.

        Returns:
            torch.Tensor: scalar loss.
        """
        gf_loss = 0
        self.specs = specs
        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = init_sde(self.specs['sde_mode'])
        for _ in range(self.specs['repeat_num']):
            gf_loss += self.loss_joint_fn(
                model=self,
                data=data,
                marginal_prob_func=self.marginal_prob_fn,
                sde_fn=self.sde_fn,
                likelihood_weighting=self.specs['likelihood_weighting'],
                teacher_model=teacher_model,
                pts_feat_teacher=pts_feat_teacher
            )
        gf_loss /= self.specs['repeat_num']
        return gf_loss

    def _denoising_score_loss(self, model, data, x0, bs, sample_key, marginal_prob_func, eps,
                              teacher_model, pts_feat_teacher):
        """Return one Monte-Carlo estimate of the std**2-weighted denoising score-matching loss.

        Draws ``t ~ U[eps, 1)`` per item and perturbs ``x0`` to ``mu + std * z``. The
        sample is stored in ``data[sample_key]`` and ``t`` in ``data['t']``. The loss
        compares ``model(data)`` with ``-z / std``, or with ``teacher_model(data)``
        evaluated on ``pts_feat_teacher``.
        """
        # Draw t on x0's device so it always matches the data it perturbs.
        random_t = torch.rand(bs, device=x0.device) * (1. - eps) + eps
        random_t = random_t.unsqueeze(-1)  # [bs, 1]
        mu, std = marginal_prob_func(x0, random_t)
        std = std.view(-1, 1)

        z = torch.randn_like(x0)
        perturbed_x = mu + z * std
        data[sample_key] = perturbed_x
        data['t'] = random_t
        estimated_score = model(data)

        if teacher_model is None:
            # Score of the Gaussian perturbation kernel: -(x - mu) / std**2 = -z / std.
            target_score = - z * std / (std ** 2)
        else:
            # The teacher sees its own point features; the student's are restored
            # afterwards, even if the teacher raises.
            pts_feat_student = data['pts_feat'].clone()
            data['pts_feat'] = pts_feat_teacher
            try:
                target_score = teacher_model(data)
            finally:
                data['pts_feat'] = pts_feat_student

        loss_weighting = std ** 2
        per_item = torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1)
        return torch.mean(per_item)

    def loss_fn(self, model, data, marginal_prob_func, sde_fn, eps=1e-5,
                likelihood_weighting=False, teacher_model=None, pts_feat_teacher=None):
        """Denoising score-matching loss for poses (one random draw).

        Sets ``data['type'] = 'pose'``, ``data['sampled_pose']`` and ``data['t']``.

        Args:
            model (callable): score model called as ``model(data)``.
            data (dict): needs 'zero_mean_pts' (batch size from dim 0) and
                'zero_mean_gt_pose' [bs, pose_dim].
            marginal_prob_func (callable): ``(x0, t) -> (mean, std)``.
            sde_fn: unused, kept for interface compatibility.
            eps (float): lower bound of the sampled time.
            likelihood_weighting (bool): unused, kept for interface compatibility.
            teacher_model (callable, optional): if given, its output replaces the analytic target.
            pts_feat_teacher (torch.Tensor, optional): 'pts_feat' used for the teacher call.

        Returns:
            torch.Tensor: scalar loss.
        """
        data['type'] = 'pose'
        pts = data['zero_mean_pts']
        gt_pose = data['zero_mean_gt_pose']
        return self._denoising_score_loss(model, data, gt_pose, pts.shape[0], 'sampled_pose',
                                          marginal_prob_func, eps, teacher_model, pts_feat_teacher)

    def loss_joint_fn(self, model, data, marginal_prob_func, sde_fn, eps=1e-5,
                      likelihood_weighting=False, teacher_model=None, pts_feat_teacher=None):
        """Denoising score-matching loss for joint parameters (one random draw).

        The clean sample is ``cat(zero_mean_gt_xyz, zero_mean_gt_rpy)`` on the last
        axis. Sets ``data['type'] = 'joint'``, ``data['sampled_joint']`` and ``data['t']``.
        Arguments are as in ``loss_fn``.

        Returns:
            torch.Tensor: scalar loss.
        """
        data['type'] = 'joint'
        pts = data['zero_mean_pts']
        gt_xyz = data['zero_mean_gt_xyz']
        gt_rpy = data['zero_mean_gt_rpy']
        gt_joint = torch.cat((gt_xyz, gt_rpy), dim=-1)
        return self._denoising_score_loss(model, data, gt_joint, pts.shape[0], 'sampled_joint',
                                          marginal_prob_func, eps, teacher_model, pts_feat_teacher)

    def compute_miou_loss(self, pred_seg_per_point, gt_seg_onehot):
        """Return ``mean(1 - soft IoU)``, with intersection and union summed over dim 1."""
        dot = torch.sum(pred_seg_per_point * gt_seg_onehot, dim=1)
        denominator = torch.sum(pred_seg_per_point, dim=1) + torch.sum(gt_seg_onehot, dim=1) - dot
        mIoU = dot / (denominator + 1e-10)
        return torch.mean(1.0 - mIoU)

    def forward(self, data):
        """Estimate the score of the noisy sample in ``data``.

        Args:
            data (dict) with keys:
                'type': 'pose' or 'joint'.
                'pts_feat': [bs, 1024], or [bs, 1280, num_pts] with per_point_feature (pose only).
                'sampled_pose': [bs, pose_dim] (type 'pose').
                'sampled_joint': [bs, 6 * (num_parts - 1)] (type 'joint').
                't': [bs, 1].

        Returns:
            torch.Tensor: head output divided by ``std(t) + 1e-7``.

        Raises:
            NotImplementedError: unknown ``data['type']`` or regression head.
        """
        if data['type'] == 'pose':
            pts_feat = data['pts_feat']
            sampled_pose = data['sampled_pose']
            t = data['t']

            t_feat = self.t_encoder(t.squeeze(1))
            pose_feat = self.pose_encoder(sampled_pose)

            if self.per_point_feature:
                # Broadcast the global t / pose features to every point.
                num_pts = pts_feat.shape[-1]
                t_feat = t_feat.unsqueeze(-1).repeat(1, 1, num_pts)
                pose_feat = pose_feat.unsqueeze(-1).repeat(1, 1, num_pts)
                total_feat = torch.cat([pts_feat, t_feat, pose_feat], dim=1)
            else:
                total_feat = torch.cat([pts_feat, t_feat, pose_feat], dim=-1)
            _, std = self.marginal_prob_func(total_feat, t)

            if self.regression_head == 'RT':
                out_score = self.fusion_tail(total_feat) / (std + 1e-7)
            elif self.regression_head == 'R_and_T':
                rot = self.fusion_tail_rot(total_feat)
                trans = self.fusion_tail_trans(total_feat)
                out_score = torch.cat([rot, trans], dim=-1) / (std + 1e-7)
            elif self.regression_head == 'Rx_Ry_and_T':
                rot_x = self.fusion_tail_rot_x(total_feat)
                rot_y = self.fusion_tail_rot_y(total_feat)
                trans = self.fusion_tail_trans(total_feat)
                out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std + 1e-7)
            else:
                raise NotImplementedError
            return out_score

        if data['type'] == 'joint':
            pts_feat = data['pts_feat']
            sampled_joint = data['sampled_joint']
            t = data['t']
            t_feat = self.t_encoder(t.squeeze(1))
            pose_feat = self.joint_encoder(sampled_joint.float())

            total_feat = torch.cat([pts_feat, t_feat, pose_feat], dim=-1)
            _, std = self.marginal_prob_func(total_feat, t)

            joint_xyz = self.joint_xyz_head(total_feat)
            joint_rpy = self.joint_rpy_head(total_feat)
            return torch.cat([joint_xyz, joint_rpy], dim=-1) / (std + 1e-7)

        # Previously this fell through and silently returned None.
        raise NotImplementedError(f"unsupported data type {data['type']!r}")

    @staticmethod
    def _repeat_batch(data, bs, repeat_num, skip_keys=('seg', 'atc', 'file_name', 'type')):
        """Tile each tensor in ``data`` from ``[bs, ...]`` to ``[bs * repeat_num, ...]``.

        Copies of one item are contiguous: item 0 repeated ``repeat_num`` times, then
        item 1, and so on. Keys in ``skip_keys`` are left out of the result.
        """
        repeated = {}
        for key, value in data.items():
            if key in skip_keys:
                continue
            shape = list(value.shape)
            tiled = value.unsqueeze(1).repeat([1, repeat_num] + [1] * (len(shape) - 1))
            shape[0] = bs * repeat_num
            repeated[key] = tiled.view(shape)
        return repeated

    def pred_pose_func(self, data, repeat_num, save_path='./visualization_results', return_average_res=False, init_x=None,
                       T0=None, return_process=False):
        """Draw ``repeat_num`` pose samples per input with the ODE sampler (no grad).

        Args:
            data (dict): batch with 'pts' (for the batch size), 'pts_feat', 'type'
                and other per-item tensors. 'seg', 'atc' and 'file_name' are not tiled.
            repeat_num (int): samples per input.
            save_path (str): unused.
            return_average_res (bool): also return the samples as [quat_wxyz, t] and
                their per-input average.
            init_x (torch.Tensor, optional): [bs, pose_dim] start states, tiled like the data.
            T0 (float, optional): sampler start time. Defaults to ``self.T``.
            return_process (bool): also return the sampler's intermediate states.

        Returns:
            pred_pose [bs, repeat_num, pose_dim]. With return_average_res it also returns
            pred_pose_q_wxyz [bs, repeat_num, 7] and average_pred_pose_q_wxyz [bs, 7].
            With return_process it also returns in_process_sample, reshaped to
            [bs, repeat_num, K, -1] where K is the sampler output's dim 1. Without
            averaging, the process variant is returned as a list.
        """
        self.is_testing = True

        with torch.no_grad():
            bs = data['pts'].shape[0]
            self.pts_feature = True

            repeated_data = self._repeat_batch(data, bs, repeat_num)
            repeated_init_x = None if init_x is None else init_x.unsqueeze(1).repeat(1, repeat_num, 1).view(
                bs * repeat_num, -1)
            repeated_data['type'] = data['type']

            in_process_sample, res = self.sample(repeated_data, 'ode', init_x=repeated_init_x, T0=T0)
            pred_pose = res.reshape(bs, repeat_num, -1)
            in_process_sample = in_process_sample.reshape(bs, repeat_num, in_process_sample.shape[1], -1)

            self.pts_feature = False

            if return_average_res:
                # `import pytorch3d` alone does not guarantee the `transforms` submodule is loaded.
                import pytorch3d.transforms
                # Last 3 entries are translation; the rest encode rotation.
                rot_matrix = get_rot_matrix(res[:, :-3], self.specs['pose_mode'])
                quat_wxyz = pytorch3d.transforms.matrix_to_quaternion(rot_matrix)
                res_q_wxyz = torch.cat((quat_wxyz, res[:, -3:]), dim=-1)
                pred_pose_q_wxyz = res_q_wxyz.reshape(bs, repeat_num, -1)

                average_pred_pose_q_wxyz = torch.zeros((bs, 7)).to(pred_pose_q_wxyz.device)
                average_pred_pose_q_wxyz[:, :4] = average_quaternion_batch(pred_pose_q_wxyz[:, :, :4])
                average_pred_pose_q_wxyz[:, 4:] = torch.mean(pred_pose_q_wxyz[:, :, 4:], dim=1)
                if return_process:
                    return pred_pose, pred_pose_q_wxyz, average_pred_pose_q_wxyz, in_process_sample
                return pred_pose, pred_pose_q_wxyz, average_pred_pose_q_wxyz

            if return_process:
                return [pred_pose, in_process_sample]
            return pred_pose

    def pred_joint_func(self, data, repeat_num, save_path='./visualization_results', return_average_res=False, init_x=None,
                        T0=None, return_process=False):
        """Draw ``repeat_num`` joint samples per input with the ODE sampler (no grad).

        Arguments are as in ``pred_pose_func`` (``save_path`` is unused).

        Returns:
            pred_joint [bs, repeat_num, D]. With return_average_res it also returns the
            mean over samples, [bs, D]. With return_process it also returns
            in_process_sample, reshaped to [bs, repeat_num, K, -1]. Without averaging,
            the process variant is returned as a list.
        """
        self.is_testing = True

        with torch.no_grad():
            bs = data['pts'].shape[0]
            self.pts_feature = True

            repeated_data = self._repeat_batch(data, bs, repeat_num)
            repeated_data['type'] = data['type']
            repeated_init_x = None if init_x is None else init_x.unsqueeze(1).repeat(1, repeat_num, 1).view(
                bs * repeat_num, -1)

            in_process_sample, res = self.sample(repeated_data, 'ode', init_x=repeated_init_x, T0=T0)
            pred_joint = res.reshape(bs, repeat_num, -1)
            in_process_sample = in_process_sample.reshape(bs, repeat_num, in_process_sample.shape[1], -1)

            self.pts_feature = False

            if return_average_res:
                average_pred_joint = torch.mean(pred_joint, dim=1)
                if return_process:
                    return pred_joint, average_pred_joint, in_process_sample
                return pred_joint, average_pred_joint

            if return_process:
                return [pred_joint, in_process_sample]
            return pred_joint

    def sample(self, data, sampler, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
        """Run a conditional sampler with this network as the score model.

        Args:
            data (dict): conditioning batch. Its 'type' selects the pose or joint
                ODE sampler.
            sampler (str): 'pc' (predictor-corrector) or 'ode'.
            atol, rtol (float): ODE tolerances ('ode' only).
            snr (float): signal-to-noise ratio ('pc' only).
            denoise (bool): forwarded to the ODE sampler.
            init_x (torch.Tensor, optional): initial state.
            T0 (float, optional): ODE start time. Defaults to ``self.T``.

        Returns:
            tuple: (in_process_sample, res) as returned by the sampler.

        Raises:
            NotImplementedError: unknown ``sampler``.
        """
        self.device = device
        if sampler == 'pc':
            in_process_sample, res = cond_pc_sampler(
                score_model=self,
                data=data,
                prior=self.prior_fn,
                sde_coeff=self.sde_fn,
                # Was `self.cfg.sampling_steps`, but `self.cfg` is never defined.
                num_steps=self.specs['sampling_steps'],
                snr=snr,
                device=self.device,
                eps=self.sampling_eps,
                pose_mode=self.specs['pose_mode'],
                init_x=init_x
            )
        elif sampler == 'ode':
            T0 = self.T if T0 is None else T0
            ode_sampler = pose_cond_ode_sampler if data['type'] == 'pose' else joint_cond_ode_sampler
            in_process_sample, res = ode_sampler(
                score_model=self,
                data=data,
                prior=self.prior_fn,
                sde_coeff=self.sde_fn,
                atol=atol,
                rtol=rtol,
                device=self.device,
                eps=self.sampling_eps,
                T=T0,
                num_steps=self.specs['sampling_steps'],
                pose_mode=self.specs['pose_mode'],
                denoise=denoise,
                init_x=init_x
            )
        else:
            raise NotImplementedError

        return in_process_sample, res

class PoseDecoderNet(nn.Module):
    """Denoiser that maps a noisy pose to a denoised pose.

    ``forward`` returns ``c_skip * sampled_pose + c_out * F(...)``, where
    ``c_skip = c_in = 1`` and ``c_out = sigma(t)``. The network ``F`` is
    conditioned on point features and on ``log(sigma(t) / 2)``.
    """

    def __init__(self, marginal_prob_func, sigma_data=1.4148, pose_mode='quat_wxyz', regression_head='RT'):
        super(PoseDecoderNet, self).__init__()
        self.sigma_data = sigma_data  # stored; not read by forward
        self.regression_head = regression_head
        self.act = nn.ReLU(True)
        pose_dim = get_pose_dim(pose_mode)
        feat_dim = 128 + 256 + 1024  # sigma feature + pose feature + point feature

        # Pose vector -> 256-d feature.
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, 256),
            self.act,
            nn.Linear(256, 256),
            self.act,
        )

        # Noise level -> 128-d feature.
        self.sigma_encoder = nn.Sequential(
            PositionalEmbedding(num_channels=128),
            nn.Linear(128, 128),
            self.act,
        )

        # The last layer of every head starts at zero, so the initial network output
        # is zero and forward() initially returns sampled_pose unchanged.
        init_zero = dict(init_mode='kaiming_uniform', init_weight=0, init_bias=0)

        def head(hidden_dim, out_dim):
            # Linear -> ReLU -> zero-initialised Linear (built in this order).
            return nn.Sequential(
                nn.Linear(feat_dim, hidden_dim),
                self.act,
                Linear(hidden_dim, out_dim, **init_zero),
            )

        if regression_head == 'RT':
            self.fusion_tail = head(512, pose_dim)
        elif regression_head == 'R_and_T':
            self.fusion_tail_rot = head(256, pose_dim - 3)
            self.fusion_tail_trans = head(256, 3)
        elif regression_head == 'Rx_Ry_and_T':
            if pose_mode != 'rot_matrix':
                raise NotImplementedError
            self.fusion_tail_rot_x = head(256, 3)
            self.fusion_tail_rot_y = head(256, 3)
            self.fusion_tail_trans = head(256, 3)
        else:
            raise NotImplementedError

        self.marginal_prob_func = marginal_prob_func

    def forward(self, data):
        """Return the denoised pose.

        Args:
            data (dict) with keys:
                'pts_feat': point features, concatenated on the last axis.
                'sampled_pose': [bs, pose_dim] noisy pose.
                't': [bs, 1] diffusion time.

        Raises:
            NotImplementedError: unknown regression head.
        """
        pts_feat = data['pts_feat']
        sampled_pose = data['sampled_pose']
        t = data['t']
        _, sigma_t = self.marginal_prob_func(None, t)

        # Preconditioning coefficients.
        c_skip, c_in = 1, 1
        c_out = sigma_t
        c_noise = torch.log(sigma_t / 2)

        pose_feat = self.pose_encoder(sampled_pose * c_in)
        sigma_feat = self.sigma_encoder(c_noise.squeeze(1))
        total_feat = torch.cat([pts_feat, sigma_feat, pose_feat], dim=-1)

        head = self.regression_head
        if head == 'RT':
            nn_output = self.fusion_tail(total_feat)
        elif head == 'R_and_T':
            nn_output = torch.cat([self.fusion_tail_rot(total_feat),
                                   self.fusion_tail_trans(total_feat)], dim=-1)
        elif head == 'Rx_Ry_and_T':
            nn_output = torch.cat([self.fusion_tail_rot_x(total_feat),
                                   self.fusion_tail_rot_y(total_feat),
                                   self.fusion_tail_trans(total_feat)], dim=-1)
        else:
            raise NotImplementedError

        return c_skip * sampled_pose + c_out * nn_output


